#ifndef _RAY_MACHING_H
#define _RAY_MACHING_H

#include <stdint.h>
#include <torch/torch.h>


extern "C" __declspec(dllexport) void near_far_from_aabb(const at::Tensor* rays_o, const at::Tensor* rays_d, const at::Tensor* aabb, const uint32_t N, const float min_near, at::Tensor* nears, at::Tensor* fars);
extern "C" __declspec(dllexport) void sph_from_ray(const at::Tensor* rays_o, const at::Tensor* rays_d, const float radius, const uint32_t N, at::Tensor* coords);
extern "C" __declspec(dllexport) void morton3D(const at::Tensor* coords, const uint32_t N, at::Tensor* indices);
extern "C" __declspec(dllexport) void morton3D_invert(const at::Tensor* indices, const uint32_t N, at::Tensor* coords);
extern "C" __declspec(dllexport) void packbits(const at::Tensor* grid, const uint32_t N, const float density_thresh, at::Tensor* bitfield);

extern "C" __declspec(dllexport) void march_rays_train(const at::Tensor* rays_o, const at::Tensor* rays_d, const at::Tensor* grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor* nears, const at::Tensor* fars, at::Tensor* xyzs, at::Tensor* dirs, at::Tensor* deltas, at::Tensor* rays, at::Tensor* counter, at::Tensor* noises);
extern "C" __declspec(dllexport) void composite_rays_train_forward(const at::Tensor* sigmas, const at::Tensor* rgbs, const at::Tensor* deltas, const at::Tensor* rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor* weights_sum, at::Tensor* depth, at::Tensor* image);
extern "C" __declspec(dllexport) void composite_rays_train_backward(const at::Tensor* grad_weights_sum, const at::Tensor* grad_image, const at::Tensor* sigmas, const at::Tensor* rgbs, const at::Tensor* deltas, const at::Tensor* rays, const at::Tensor* weights_sum, const at::Tensor* image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor* grad_sigmas, at::Tensor* grad_rgbs);

extern "C" __declspec(dllexport) void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor* rays_alive, const at::Tensor* rays_t, const at::Tensor* rays_o, const at::Tensor* rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor* grid, const at::Tensor* nears, const at::Tensor* fars, at::Tensor* xyzs, at::Tensor* dirs, at::Tensor* deltas, at::Tensor* noises);
extern "C" __declspec(dllexport) void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor* rays_alive, at::Tensor* rays_t, at::Tensor* sigmas, at::Tensor* rgbs, at::Tensor* deltas, at::Tensor* weights_sum, at::Tensor* depth, at::Tensor* image);


#endif